package com.ezlcp.gateway.filter;

import com.ezlcp.gateway.utils.XssCleanRuleUtils;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * @author Elwin ZHANG
 * @description: 跨域脚本攻击响应参数全局过滤器<br />
 * 针对系统已经存在XSS注入时的过滤器，如果系统是干净的可以不启用
 * @date 2023/2/1 15:37
 */
//先不使用 @Component
public class XssResponseGlobalFilter implements GlobalFilter, Ordered {

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpResponse originalResponse = exchange.getResponse();
        DataBufferFactory bufferFactory = originalResponse.bufferFactory();
        ServerHttpResponseDecorator decoratedResponse = new ServerHttpResponseDecorator(originalResponse) {
            @Override
            public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
                if (body instanceof Flux) {
                    Flux<? extends DataBuffer> fluxBody = (Flux<? extends DataBuffer>) body;
                    return super.writeWith(fluxBody.buffer().map(dataBuffer -> {
                        DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory();
                        DataBuffer join = dataBufferFactory.join(dataBuffer);
                        try {
                            // body重写
                            byte[] content = new byte[join.readableByteCount()];
                            join.read(content);
                            String rawResponse = new String(content, StandardCharsets.UTF_8);
                            String newResponseStr = XssCleanRuleUtils.xssClean(rawResponse);
                            byte[] newResponseBytes = newResponseStr.getBytes(StandardCharsets.UTF_8);

                            // header重写
                            HttpHeaders httpHeaders = originalResponse.getHeaders();
                            httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
                            httpHeaders.setContentLength(newResponseBytes.length);

                            return bufferFactory.wrap(newResponseBytes);
                        } finally {
                            // DataBuffer看上去几乎等于Netty的ByteBuff，一定要保障显式释放，否则直接内存溢出!
                            DataBufferUtils.release(join);
                        }
                    }));
                }
                return super.writeWith(body);
            }
        };

        return chain.filter(exchange.mutate().response(decoratedResponse).build());
    }

    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE + 1000;
    }
}